package deeplearn;

import org.deeplearning4j.core.storage.StatsStorage;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.nn.conf.BackpropType;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration.GraphBuilder;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.graph.StackVertex;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.ui.api.UIServer;
import org.deeplearning4j.ui.model.storage.InMemoryStatsStorage;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.distribution.impl.NormalDistribution;
import org.nd4j.linalg.dataset.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.ops.NDRandom;
import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.linalg.ops.transforms.Transforms;

import javax.swing.*;
import java.awt.*;
import java.awt.image.BufferedImage;
import java.io.File;
import java.util.Map;

/**
 * 说明：
 * 博客：https://blog.csdn.net/dong_lxkm/article/details/103125742
 *     1、dl4j并没有提供像keras那样冻结某些层参数的方法，这里采用设置learningrate为0的方法，来冻结某些层的参数
 *
 *     2、这个的更新器，用的是sgd，不能用其他的（比方说Adam、Rmsprop），因为这些自适应更新器会考虑前面batch的梯度作为本次更新的梯度，达不到不更新参数的目的
 *
 *     3、这里用了StackVertex，沿着第一维合并张量，也就是合并真实数据样本和Generator产生的数据样本，共同训练Discriminator
 *
 *     4、训练过程中多次update   Discriminator的参数，以便量出最大距离，然后更新Generator一次
 *
 *     5、进行10w次迭代
		6、数据被下载到C:\Users\Administrator\.deeplearning4j\data\MNIST

 		拆成两个  net
 */
public class WGAN_GP_S {

	// Swing window/panel reused across visualize() calls; created lazily on first use.
	private static JFrame frame;
	private static JPanel panel;

	static int height = 28; // input image height (MNIST)
	static int width = 28; // input image width (MNIST)
	static int channels = 1; // input image channel count (grayscale)
	public static final int batch = 2; // mini-batch size per training step
	public static final int latent_dim  =10; // length of the generator's noise vector z
	// Learning rate for SGD; per-layer it is set to 0 to emulate layer freezing (see trainD/trainG).
	static double lr = 0.01;
	static String modelg = "F:/face/wgang.zip"; // generator model save path
	static String modeld = "F:/face/wgand.zip"; // critic (discriminator) model save path
	/**
	 * Entry point: trains a WGAN on MNIST.
	 *
	 * Two structurally identical ComputationGraphs are built ("generator" and
	 * "critic"); each contains both the generator layers (g1..g3) and the
	 * critic layers (d1..d3, out). DL4J offers no Keras-style layer freezing,
	 * so the half that must stay fixed gets its learning rate set to 0 — only
	 * valid with plain SGD, since adaptive updaters (Adam, RMSProp) would still
	 * move zero-lr parameters via their accumulated state. After each training
	 * phase the freshly trained parameters are copied into the sibling graph.
	 *
	 * Fixes versus the original:
	 *  - the fitted ImagePreProcessingScaler is now attached to the iterator
	 *    (it was fitted but never applied, leaving pixels in [0,1] while the
	 *    renderer assumes [-1,1]);
	 *  - the parameter-sync calls pass (generator, critic) so the trained
	 *    critic layers flow critic -&gt; generator and the trained generator
	 *    layers flow generator -&gt; critic (the arguments were swapped, which
	 *    overwrote each phase's training with stale parameters);
	 *  - all generated rows are collected for visualization (the loop bound
	 *    was hard-coded to 1, leaving null array entries);
	 *  - the generator is saved to modelg and the critic to modeld (the two
	 *    paths were swapped).
	 *
	 * @param args unused
	 * @throws Exception if the MNIST download or a model save fails
	 */
	public static void main(String[] args) throws Exception {

		// Shared topology: noise (input1) -> g1..g3 emits a fake 28*28 image;
		// StackVertex concatenates the real batch (input2) with the fakes along
		// dimension 0 so the critic scores real and fake samples in one pass.
		// NOTE(review): GELU on the Wasserstein output is unusual — the critic
		// score is typically unbounded (Activation.IDENTITY); confirm intent.
		final GraphBuilder generatorModel = new NeuralNetConfiguration.Builder().updater(new Sgd(lr))
				.weightInit(WeightInit.XAVIER).graphBuilder().backpropType(BackpropType.Standard)
				.addInputs("input1", "input2")
				.addLayer("g1",
						new DenseLayer.Builder().nIn(10).nOut(128).activation(Activation.RELU)
								.weightInit(WeightInit.XAVIER).build(),
						"input1")
				.addLayer("g2",
						new DenseLayer.Builder().nIn(128).nOut(512).activation(Activation.RELU)
								.weightInit(WeightInit.XAVIER).build(),
						"g1")
				.addLayer("g3",
						new DenseLayer.Builder().nIn(512).nOut(28 * 28).activation(Activation.RELU)
								.weightInit(WeightInit.XAVIER).build(),
						"g2")
				.addVertex("stack", new StackVertex(), "input2", "g3")
				.addLayer("d1",
						new DenseLayer.Builder().nIn(28 * 28).nOut(256).activation(Activation.RELU)
								.weightInit(WeightInit.XAVIER).build(),
						"stack")
				.addLayer("d2",
						new DenseLayer.Builder().nIn(256).nOut(128).activation(Activation.RELU)
								.weightInit(WeightInit.XAVIER).build(),
						"d1")
				.addLayer("d3",
						new DenseLayer.Builder().nIn(128).nOut(128).activation(Activation.RELU)
								.weightInit(WeightInit.XAVIER).build(),
						"d2")
				.addLayer("out", new OutputLayer.Builder(new LossWasserstein()).nIn(128).nOut(1)
						.activation(Activation.GELU).build(), "d3")
				.setOutputs("out");

		// Identical architecture; only the per-layer learning rates differ at
		// training time (see trainD/trainG).
		final GraphBuilder criticModel = new NeuralNetConfiguration.Builder().updater(new Sgd(lr))
				.weightInit(WeightInit.XAVIER).graphBuilder().backpropType(BackpropType.Standard)
				.addInputs("input1", "input2")
				.addLayer("g1",
						new DenseLayer.Builder().nIn(10).nOut(128).activation(Activation.RELU)
								.weightInit(WeightInit.XAVIER).build(),
						"input1")
				.addLayer("g2",
						new DenseLayer.Builder().nIn(128).nOut(512).activation(Activation.RELU)
								.weightInit(WeightInit.XAVIER).build(),
						"g1")
				.addLayer("g3",
						new DenseLayer.Builder().nIn(512).nOut(28 * 28).activation(Activation.RELU)
								.weightInit(WeightInit.XAVIER).build(),
						"g2")
				.addVertex("stack", new StackVertex(), "input2", "g3")
				.addLayer("d1",
						new DenseLayer.Builder().nIn(28 * 28).nOut(256).activation(Activation.RELU)
								.weightInit(WeightInit.XAVIER).build(),
						"stack")
				.addLayer("d2",
						new DenseLayer.Builder().nIn(256).nOut(128).activation(Activation.RELU)
								.weightInit(WeightInit.XAVIER).build(),
						"d1")
				.addLayer("d3",
						new DenseLayer.Builder().nIn(128).nOut(128).activation(Activation.RELU)
								.weightInit(WeightInit.XAVIER).build(),
						"d2")
				.addLayer("out", new OutputLayer.Builder(new LossWasserstein()).nIn(128).nOut(1)
						.activation(Activation.GELU).build(), "d3")
				.setOutputs("out");

		ComputationGraph generator = new ComputationGraph(generatorModel.build());
		ComputationGraph critic = new ComputationGraph(criticModel.build());

		generator.init();
		critic.init();

		System.out.println(generator.summary());
		System.out.println(critic.summary());

		// Training UI and periodic score logging.
		UIServer uiServer = UIServer.getInstance();
		StatsStorage statsStorage = new InMemoryStatsStorage();
		uiServer.attach(statsStorage);
		generator.setListeners(new ScoreIterationListener(100));
		critic.setListeners(new ScoreIterationListener(100));

		// MNIST data, rescaled from [0,1] to [-1,1].
		DataSetIterator train = new MnistDataSetIterator(batch, true, 12345);
		DataNormalization scalerA = new ImagePreProcessingScaler(-1, 1);
		scalerA.fit(train);
		train.setPreProcessor(scalerA); // fix: scaler was fitted but never attached

		// Critic targets: StackVertex order is [real (input2), fake (g3)];
		// real rows get -1, fake rows +1 (Wasserstein labels).
		INDArray labelD = Nd4j.vstack(Nd4j.ones(batch, 1).mul(-1), Nd4j.ones(batch, 1));

		// Generator targets: drive the critic to score the whole stack as real.
		INDArray labelG = Nd4j.ones(2 * batch, 1).mul(-1);

		for (int i = 1; i <= 100000; i++) {
			if (!train.hasNext()) {
				train.reset();
			}
			INDArray trueExp = train.next().getFeatures();

			// --- Critic phase: several updates per generator step so the
			// critic approximates the Wasserstein distance well.
			INDArray z = Nd4j.rand(new NormalDistribution(), new long[] { batch, 10 });
			MultiDataSet dataSetD = new MultiDataSet(new INDArray[] { z, trueExp },
					new INDArray[] { labelD });
			for (int m = 0; m < 10; m++) {
				trainD(critic, dataSetD);
			}
			// Copy the freshly trained critic layers into the generator graph
			// (fix: arguments were previously swapped).
			updateB2A(generator, critic);

			// --- Generator phase: one update, then sync g-layers back.
			MultiDataSet dataSetG = new MultiDataSet(new INDArray[] { z, trueExp },
					new INDArray[] { labelG });
			trainG(generator, dataSetG);
			// Copy the freshly trained generator layers into the critic graph
			// (fix: arguments were previously swapped).
			updateA2B(generator, critic);

			// Periodically render the current fake samples.
			Map<String, INDArray> map = generator.feedForward(
					new INDArray[] { Nd4j.rand(new NormalDistribution(), new long[] { batch, 10 }), trueExp }, false);
			INDArray indArray = map.get("g3");
			if (i % 10 == 0) {
				long size = indArray.size(0);
				INDArray[] samples = new INDArray[(int) size];
				for (int j = 0; j < size; j++) { // fix: was j < 1, leaving nulls
					samples[j] = indArray.getRow(j);
				}
				visualize(samples);
			}

			// Periodic checkpoints (fix: paths were swapped).
			if (i % 100 == 0) {
				generator.save(new File(modelg), true);
				critic.save(new File(modeld), true);
			}
		}

	}

	/**
	 * Copies the generator-side layer parameters (g1..g3) from graph A into
	 * graph B, keeping the two structurally identical networks in sync.
	 */
	private static void updateA2B(ComputationGraph ganA, ComputationGraph ganB) {
		for (String layer : new String[] { "g1", "g2", "g3" }) {
			ganB.getLayer(layer).setParams(ganA.getLayer(layer).params());
		}
	}
	/**
	 * Copies the critic-side layer parameters (d1..d3 and the output layer)
	 * from graph B into graph A, keeping the two networks in sync.
	 */
	private static void updateB2A(ComputationGraph ganA, ComputationGraph ganB) {
		for (String layer : new String[] { "d1", "d2", "d3", "out" }) {
			ganA.getLayer(layer).setParams(ganB.getLayer(layer).params());
		}
	}

	/**
	 * Computes the WGAN-GP gradient penalty for a batch of interpolated
	 * samples: mean((1 - ||grad||_2)^2) over the batch.
	 *
	 * NOTE(review): currently unused — only referenced from commented-out code
	 * in main. Concerns to confirm before enabling:
	 *  - each Nd4j.sum call reduces the array, so the later axis indices taken
	 *    from arange no longer refer to the original dimensions; summing over
	 *    all non-batch axes in a single call looks safer — verify;
	 *  - the comment below mentions lambda, but no lambda coefficient is ever
	 *    multiplied in;
	 *  - the y_true / y_pred flags are accepted but never read.
	 *
	 * @param model            graph whose backprop gradient is penalized
	 * @param y_true           unused
	 * @param y_pred           unused
	 * @param averaged_samples interpolated real/fake samples (the epsilon used for backprop)
	 * @return scalar mean gradient penalty over the batch
	 */
	public static INDArray gradient_penalty_loss(ComputationGraph model,boolean y_true,boolean y_pred,INDArray averaged_samples) {

        //Computes gradient penalty based on prediction and weighted real / fake samples
		Gradient gradients = model.backpropGradient(averaged_samples);
		INDArray gradient = gradients.gradient();
		//compute the euclidean norm by squaring ...
		INDArray gradients_sqr = gradient.mul(gradient);
		//  ... summing over the rows ...
		INDArray axis = Nd4j.arange(1, gradients_sqr.shape().length);
		for(int i=0;i<axis.rows();i++){
			gradients_sqr = Nd4j.sum(gradients_sqr,axis.getInt(i));
		}
		INDArray gradients_sqr_sum = gradients_sqr;
		//  ... and sqrt
		INDArray gradient_l2_norm = Transforms.sqrt(gradients_sqr_sum);
		//compute lambda * (1 - ||grad||)^2 still for each single sample
		INDArray gradient_l2_norm_sqr = Nd4j.ones(gradient_l2_norm.shape()).sub(gradient_l2_norm);
		INDArray gradient_penalty = gradient_l2_norm_sqr.mul(gradient_l2_norm_sqr);
		//return the mean as loss over all the batch samples
		return Nd4j.mean(gradient_penalty);
	}

	/**
	 * Random weighted average of a real/fake sample pair, used to build the
	 * interpolated samples for the WGAN-GP gradient penalty:
	 * alpha * inputs[0] + (1 - alpha) * inputs[1].
	 *
	 * Fix: alpha's leading dimension was a hard-coded 32; it now follows the
	 * actual batch size of the inputs so the broadcast works for any batch.
	 *
	 * @param inputs two equally-shaped arrays: {real, fake}
	 * @return the element-wise interpolation of the two inputs
	 */
	public static INDArray RandomWeightedAverage(INDArray[] inputs){
		long batchSize = inputs[0].size(0);
		INDArray alpha = Nd4j.rand(new long[]{batchSize, 1, 1, 1});
		return (alpha.mul(inputs[0])).add((Nd4j.ones(alpha.shape()).sub(alpha)).mul(inputs[1]));
	}

 	/**
	 * One critic (discriminator) training step. DL4J has no Keras-style layer
	 * freezing, so the generator layers are "frozen" by zeroing their learning
	 * rate while the critic layers keep the real rate. Only valid with plain
	 * SGD — adaptive updaters would still move zero-lr parameters.
	 */
	public static void trainD(ComputationGraph net, MultiDataSet dataSet) {
		for (String frozen : new String[] { "g1", "g2", "g3" }) {
			net.setLearningRate(frozen, 0);
		}
		for (String trained : new String[] { "d1", "d2", "d3", "out" }) {
			net.setLearningRate(trained, lr);
		}
		net.fit(dataSet);
	}
	/**
	 * One generator training step: the mirror of trainD — the critic layers
	 * are "frozen" via a zero learning rate while the generator layers (g1..g3)
	 * are updated. Only valid with plain SGD.
	 */
	public static void trainG(ComputationGraph net, MultiDataSet dataSet) {
		for (String trained : new String[] { "g1", "g2", "g3" }) {
			net.setLearningRate(trained, lr);
		}
		for (String frozen : new String[] { "d1", "d2", "d3", "out" }) {
			net.setLearningRate(frozen, 0);
		}
		net.fit(dataSet);
	}

	/**
	 * Renders the given samples in a Swing window, one image per grid cell.
	 * The window and panel are created lazily on the first call; null or
	 * empty entries are skipped.
	 */
	private static void visualize(INDArray[] samples) {
		if (frame == null) {
			panel = new JPanel();
			panel.setLayout(new GridLayout(samples.length / 3, 1, 8, 8));

			frame = new JFrame();
			frame.setTitle("Viz");
			frame.setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE);
			frame.setLayout(new BorderLayout());
			frame.add(panel, BorderLayout.CENTER);
			frame.setVisible(true);
		}

		panel.removeAll();

		for (int k = 0; k < samples.length; k++) {
			INDArray sample = samples[k];
			if (sample != null && sample.size(0) != 0) {
				panel.add(imageFromINDArray(sample));
			}
		}

		frame.revalidate();
		frame.pack();
	}

	/**
	 * Renders a flat 784-element tensor (values assumed in [0,1]) as a 28x28
	 * grayscale image, scaled up 9x for display. Currently unused —
	 * imageFromINDArray is the active renderer.
	 */
	private static JLabel getImage(INDArray tensor) {
		BufferedImage bi = new BufferedImage(28, 28, BufferedImage.TYPE_BYTE_GRAY);
		for (int y = 0; y < 28; y++) {
			for (int x = 0; x < 28; x++) {
				int value = (int) (255 * tensor.getDouble(y * 28 + x));
				bi.getRaster().setSample(x, y, 0, value);
			}
		}
		Image imageScaled = new ImageIcon(bi).getImage()
				.getScaledInstance(9 * 28, 9 * 28, Image.SCALE_DEFAULT);
		return new JLabel(new ImageIcon(imageScaled));
	}
	/**
	 * Renders one generated sample (784 values, assumed tanh-style range
	 * [-1,1] per the 127.5 scaling below) as a 28x28 grayscale image, scaled
	 * up 9x for on-screen display.
	 *
	 * Fixes: removed the per-pixel System.out.println debug spam (784 log
	 * lines per image), and the reshaped rank-2 array is indexed with
	 * (row, col) instead of the stale 4-d (0, 0, y, x) index that did not
	 * match the reshape(28, 28) above it.
	 *
	 * @param array a sample with 784 elements
	 * @return a JLabel holding the scaled image
	 */
	private static JLabel imageFromINDArray(INDArray array) {
		array = array.reshape(28, 28);
		long[] shape = array.shape();
		int height = (int) shape[0];
		int width = (int) shape[1];
		BufferedImage image = new BufferedImage(width, height, BufferedImage.TYPE_BYTE_GRAY);
		for (int x = 0; x < width; x++) {
			for (int y = 0; y < height; y++) {
				// Map [-1,1] to [0,255] and clamp out-of-range activations.
				int gray = (int) (array.getDouble(y, x) * 127.5 + 127.5);
				gray = Math.min(gray, 255);
				gray = Math.max(gray, 0);
				image.getRaster().setSample(x, y, 0, gray);
			}
		}
		ImageIcon orig = new ImageIcon(image);
		Image imageScaled = orig.getImage().getScaledInstance((int) (9 * 28), (int) (9 * 28),
				Image.SCALE_DEFAULT);
		ImageIcon scaled = new ImageIcon(imageScaled);
		return new JLabel(scaled);
	}
}